
Batch encode: lock-free work queue with dynamic window sizing#2029

Open
sebpop wants to merge 1 commit into huggingface:main from sebpop:p2

Conversation


@sebpop sebpop commented Apr 23, 2026

Replace inputs.into_maybe_par_iter().map(...).collect() in encode_batch, encode_batch_char_offsets and encode_batch_fast with a small helper TokenizerImpl::run_batch that:

  • Dispatches to a plain inputs.into_iter().map(...).collect() serial loop when parallelism is disabled or only one thread is available, avoiding all rayon involvement for single-threaded callers.
  • At higher thread counts, uses a lock-free atomic counter (BatchWorkQueue) inside one rayon::scope with one s.spawn per worker. Each worker claims windows of item indices via a single AtomicUsize::fetch_add, takes inputs from per-slot UnsafeCell<Option<EncodeInput>>, and writes results into per-slot UnsafeCell<Option<Result<Encoding>>>. No shared mutable state outside the counter; no final collect() on a parallel iterator.

The lock-free design is motivated by aarch64 LSE atomic cost: every mutex / condvar hit on the previous parallel-iterator path was a CAS / LDADD emitted by libpthread, and those dominate small-work parallel loops at high thread counts on arm64. Replacing that path with a single fetch_add per window removes the mutex-backed per-item signaling entirely.

Cache-line / loop-tiling rationale

Shared-memory parallel loops are bottlenecked by the cache coherence protocol when two cores alternate writes to the same cache line: the line "ping-pongs" between their private L1d caches, each transfer costing dozens of cycles. To avoid that, every line should be filled by one producer core, drained (or no longer needed), and only then touched by a different core. This is the cache-aware equivalent of loop tiling / blocking: group the iteration space into chunks whose data footprint is a whole number of cache lines, and give each chunk to a single core.

The work queue enforces this in three ways:

  1. The counter itself lives on its own 64-byte cache line (#[repr(C, align(64))] on AlignedCounter). A worker's fetch_add does not evict any neighbouring data, and reads of the counter do not pull input or result payloads into the core's L1d.

  2. Each window is a contiguous run of window_size indices, so every worker owns a run of adjacent slots for the duration of one window. With MAX_WINDOW_SIZE = 8, a window covers roughly 8 * sizeof(slot) bytes -- for Option<EncodeInput> (~48 B) that is ~6 cache lines; for Option<Result<Encoding>> (more than one cache line per slot) it is even more. Within one window, a worker writes several whole cache lines before any other worker comes near them.

  3. Each slot has its own UnsafeCell (Vec<UnsafeCell<Option<T>>>). UnsafeCell<T> is #[repr(transparent)], so the heap layout is byte-identical to a plain Vec<Option<T>> (no padding, same alignment, same contiguous packing -- zero runtime overhead vs. the "unsafe fast" version that reborrows the whole Vec). What the per-slot cell buys is that self.0[i].get() returns *mut Option<T> pointing straight at slot i, without ever materialising a &mut Vec<Option<T>> that would alias the enclosing container (which is UB when two threads touch any distinct indices concurrently).
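Points 1 and 3 are both checkable at the type level. A minimal sketch, assuming the field name (AlignedCounter and the repr attributes come from the description above):

```rust
use std::cell::UnsafeCell;
use std::mem::{align_of, size_of};
use std::sync::atomic::AtomicUsize;

// Point 1: pad the counter out to a full 64-byte line so its fetch_add
// never contends with neighbouring payload data.
#[repr(C, align(64))]
struct AlignedCounter {
    value: AtomicUsize,
}

fn main() {
    // The counter occupies exactly one cache line.
    assert_eq!(align_of::<AlignedCounter>(), 64);
    assert_eq!(size_of::<AlignedCounter>(), 64);
    // Point 3: UnsafeCell is repr(transparent), so a Vec of per-slot
    // cells has exactly the layout of a plain Vec<Option<T>>.
    assert_eq!(size_of::<UnsafeCell<Option<u64>>>(), size_of::<Option<u64>>());
}
```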

At window boundaries a single cache line can be shared between two successive windows when the slot size does not divide 64 bytes. That is a sequential handoff (window N finishes writes; window N+1 then reads/writes), not a concurrent ping-pong, so the cost is at most one coherence transfer per window-pair.

Window sizing

window_size = ceil(total / (num_threads * WINDOWS_PER_THREAD)), clamped to [1, MAX_WINDOW_SIZE].

  • WINDOWS_PER_THREAD = 4 keeps several windows per thread so a slow worker on its last item does not stall the whole batch.
  • MAX_WINDOW_SIZE = 8 caps per-claim atomic latency and keeps the per-window memory footprint small enough to fit in L1d.

Examples: 100 items / 16 threads yields window_size = 2 (50 windows); 10 000 items / 16 threads yields window_size = 8 (1250 windows).
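The sizing rule above can be written out directly. The constant names appear in the PR text; the free function here is a sketch (the real helper may be shaped differently):

```rust
// Several windows per thread for load balancing; a hard cap on window
// size to bound per-claim latency and per-window L1d footprint.
const WINDOWS_PER_THREAD: usize = 4;
const MAX_WINDOW_SIZE: usize = 8;

fn window_size(total: usize, num_threads: usize) -> usize {
    // ceil(total / (num_threads * WINDOWS_PER_THREAD)), clamped to [1, 8].
    total.div_ceil(num_threads * WINDOWS_PER_THREAD).clamp(1, MAX_WINDOW_SIZE)
}
```

This reproduces the worked examples: 100 items across 16 threads gives ceil(100/64) = 2, and 10 000 items gives ceil(10000/64) = 157, clamped to 8.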

Tests

7 new unit tests in utils::batch::tests cover window sizing, TakeVec and ResultVec round-trip, and test_parallel_distribution (4 threads concurrently claiming and writing 100 slots, exercising the Sync bounds under real contention).

cargo test --lib --features http: 208 passed, 0 failed.

Perf evidence

On Vera (88-core Olympus, 176 logical),
bpe_benchmark/bpe-encode/BPE GPT2 encode batch at 88T, perf record -g --call-graph fp -F 4999.

LSE atomic instructions (the direct motivation for the lock-free counter):

  instruction                    before    after
  __aarch64_cas4_acq              3.57%   0.61%   (-5.9x)
  __aarch64_ldadd8_acq_rel        1.05%   0.08%   (-13x)
  __aarch64_swp4_rel              0.21%   0.05%
  __aarch64_ldadd8_relax          0.17%   0.24%
  __aarch64_swp4_acq              0.12%   0.00%
  __aarch64_swp8_acq_rel          0.06%   0.00%
  __aarch64_cas8_acq_rel          0.01%   0.01%
  total LSE                       ~5.2%   ~1.0%   (-4.2x)

Rayon / crossbeam-epoch:

  symbol                                                     before    after
  rayon_core::sleep::Sleep::wake_specific_thread              0.57%   0.06%   (-10x)
  crossbeam_epoch::internal::Global::try_advance             25.93%  28.38%
  crossbeam_epoch::default::with_handle                      21.41%  23.12%
  rayon_core::registry::WorkerThread::wait_until_cold         8.40%  10.72%
  rayon::iter::plumbing::bridge_producer_consumer::helper     0.20%   0.24%

bridge_producer_consumer::helper was not a hotspot on this workload before the change (0.20%) and does not move; the observable rayon-side change is Sleep::wake_specific_thread dropping ~10x because rayon::scope issues one wake per worker per batch call rather than streaming wakes per parallel-iterator split. The three remaining rayon/crossbeam ceiling symbols (try_advance + with_handle + wait_until_cold = ~62% of cycles) stay similar in percentage because total cycles decrease; absolute wall-clock per benchmark iteration drops 35 ms (295 ms -> 260 ms at 88T). Removing that rayon ceiling is a separate change.

Throughput on Vera, bpe-encode/BPE GPT2 encode batch (data/big.txt, encode_batch through the full post-processor):

  threads  before       after        change
  -------  ------       ------       ------
  1T       3.98 MiB/s   4.46 MiB/s   +12%
  88T      20.97 MiB/s  23.76 MiB/s  +13%
  176T     18.83 MiB/s  21.58 MiB/s  +15%

@ArthurZucker
Collaborator

/benchmark

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sebpop
Contributor Author

sebpop commented Apr 24, 2026

Small semantic note, since @codex flagged it:
rayon's collect::<Result<_, _>>() is best-effort fail-fast: other workers are asked to stop after an Err is observed, but in-flight items may still finish. The new implementation always processes every item, then returns the Err at the lowest failing index deterministically. For all-Ok batches the two are identical; for failing batches, the new implementation does bounded extra work, at most batch_size - first_error_index - 1 additional encode calls after the first failing item, in exchange for deterministic error selection.
That determinism comes from collecting the per-slot results back in input order after all workers finish: i.e., into_vec().into_iter().collect().
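The determinism hinges on Result's FromIterator behavior, which short-circuits at the first (lowest-index) Err when collecting in input order. A tiny illustration, with an assumed function name:

```rust
// Collecting per-slot Results in input order yields the Err at the
// lowest failing index; all-Ok input yields the full Vec.
fn collect_in_order(results: Vec<Result<u32, String>>) -> Result<Vec<u32>, String> {
    results.into_iter().collect()
}
```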

I chose not to add a shared stop flag because it would add hot-path polling to every batch to improve a rare cold path, and because deterministic-error-at-lowest-index is arguably the better user contract. This tradeoff follows Brendan Gregg's Utilization, Saturation, and Errors (USE) Method: optimize the common Utilization/Saturation path while keeping the rare Error path correct and bounded.
Happy to add it if a reviewer thinks the behavior change warrants a follow-up.

@vyalamar

vyalamar commented Apr 24, 2026

@sebpop Great timing. I was also looking at #1900 myself.
Since you benchmarked on the Vera CPU, I can help out by running a validation on some cloud hardware to see whether the scaling improvements generalize. I am curious how it handles x86 hardware (Intel Sapphire).
The plan is to spin up Graviton, AMD chiplet, and AWS x86 (c7i.metal-24xl) instances and test v0.22.2 vs. main (post-#2028) vs. this PR.

If there is a preferred benchmark command etc. for this test, I am happy to use that, @sebpop. Otherwise I'll drop a compact summary table and the raw data here once I have the machines spun up and the runs finished.

@sebpop
Contributor Author

sebpop commented Apr 25, 2026

Thanks @vyalamar, very welcome. See the recipe below — tested on Vera and Grace (Nvidia arm64); should reproduce on x86_64 (Intel and AMD) and on arm64 Graviton with no fiddling.

Build the benchmark:

git clone https://github.com/huggingface/tokenizers.git
cd tokenizers/tokenizers
cargo bench --bench bpe_benchmark --no-run

Run:

BPE=$(ls -t target/release/deps/bpe_benchmark-* | grep -v '\.d$' | head -1)
THREADS="1 64 96 128 192"
for T in $THREADS; do
    echo "=== threads=$T ==="
    RAYON_NUM_THREADS=$T $BPE --bench '^bpe-encode/BPE GPT2 encode batch$'
done

You can also run the benchmark under https://github.com/aws/aperf and check flamegraphs and other PMU metrics.

I expect AMD and Intel CPUs to behave similarly to arm64. Let me know if you hit any roadblocks.
